
import numpy as np
from channel import *
from simSweep import *
from timeit import default_timer as timer
from datetime import timedelta
import matplotlib.pyplot as plt
import sys
#import device
import os
import time

import torch
from torch import nn
import argparse
import pickle
from torchsummary import *


#--------------------------------#
from FwdBwdNeuralEqV3 import *
from neuralEQRNN import *
from simNeuralEQRNN import *
from simNeuralEQ import *
from Tx	import *
from eq	import *
from neuralEQ import *
from functions import *
import device

#--------------TPE---------------#
import hyperopt
from hyperopt import tpe, STATUS_OK, Trials, hp, fmin

# Number of amplitude levels for each supported modulation format.
_MOD_NUM = {'nrz': 2, 'pam4': 4, 'pam8': 8}


def _saveList(fileName, obj):
    """Pickle ``obj`` into ``fileName`` (overwriting it)."""
    with open(fileName, "wb") as fp:
        pickle.dump(obj, fp)


def _loadList(fileName):
    """Return the object pickled in ``fileName``.

    NOTE: ``pickle.load`` can execute arbitrary code. This is only meant for
    the local ``caching_data`` files written by ``_saveList``; never point it
    at untrusted files.
    """
    with open(fileName, "rb") as fp:
        return pickle.load(fp)


def _fwdBwdCacheFileNames(tpeCfg, snr):
    """Return ``(probFile, chOutFile, chInFile)`` cache paths for training data.

    With ``tpeCfg['forceTrainIn']`` set, fixed pre-generated file names are
    returned. Otherwise each name encodes the modulation, the training data
    size, the equalizer SBR and ``snr`` (formatted with ``%d``).
    """
    if tpeCfg['forceTrainIn']:
        return ('caching_data/probNew_less09.list',
                'caching_data/chOutNew_less09.list',
                'caching_data/chInNew_less09.list')
    return tuple(
        './caching_data/%s_%s_size%d_%s_snr%ddB.list' % (
            tpeCfg['mod'], tag, tpeCfg['dataSizeTrain'], tpeCfg['eqSBR'], snr)
        for tag in ('fwdBwdProb', 'fwdBwdProbChOut', 'fwdBwdProbChIn'))


def Objective(params):
    '''*********************************************
    hyperopt objective: build and train a neural EQ with the architecture in
    ``params`` and return its test BER as the loss to minimise.

    A model is trained for every combination of selTrainDataList, lossFnList
    and snrTrainList in ``cfg['tpe']`` (config module selected by the
    command-line ``config`` argument). Edit those lists in the config file to
    change the sweep.

    Args:
        params: hyperopt sample. Needs 'N' when cfg['tpe']['useFwdBwdNeuralEq']
            is set, otherwise 'Hidden_Size_1' and 'Hidden_Size_2'.

    Returns:
        dict with 'loss' (test BER of the LAST configuration trained in the
        sweep), 'params' and 'status' (STATUS_OK), as hyperopt expects.

    Raises:
        SystemExit: the configured modulation is not nrz/pam4/pam8.
        ValueError: a sweep list is empty, so no model was trained.

    Side effects: saves each trained model (and optionally a loss plot) under
    ./results/<name>_TPE/, and may write training-data caches to caching_data/.
    *********************************************'''
    #*************************HEADER***********************#
    startTime = time.time()
    np.random.seed(1)
    args = parsing_def()
    if './config' not in sys.path:  # don't grow sys.path on every TPE trial
        sys.path.insert(0, './config')
    config_module = __import__('config_{}'.format(args.config))
    cfg = config_module.config
    tpeCfg = cfg['tpe']

    modNum = _MOD_NUM.get(tpeCfg['mod'])
    if modNum is None:
        sys.exit('invalid modulation')

    # The delay handed to training/evaluation is inSize/4 + delayOffset.
    delay = int(tpeCfg['inSize'] / 4)
    if tpeCfg['delayOffset'] is not None:
        delayOffset = tpeCfg['delayOffset']
    else:
        # Default offset: minus the index of the largest channel SBR tap.
        chSBR = list(tpeCfg['chSBR'])
        delayOffset = -chSBR.index(max(chSBR))

    # Create the output directory up front so torch.save/savefig cannot fail
    # after a (possibly long) training run.
    resultDir = './results/%s_TPE' % args.name
    os.makedirs(resultDir, exist_ok=True)

    '''*****************************************************
    Tx and channel definition.
    Tx generates random data according to the modulation.
    Channel adds ISI and noise.
    Validation and test channels are built once here; the training channel
    is built per training SNR inside the sweep.
    *****************************************************'''
    tx = Tx(mod=tpeCfg['mod'])

    #@@ Validation sequence used during training.
    chInValid = tx.run(int(tpeCfg['dataSizeValid']))
    chValid = Channel(sbr=tpeCfg['chSBR'], snr=tpeCfg['snrValid'])
    chOutValid = chValid.run(chIn=chInValid, flagN=tpeCfg['noiseFlag'])

    #@@ Test sequence for the final evaluation.
    chInTest = tx.run(int(tpeCfg['dataSizeTest']))
    chTest = Channel(sbr=tpeCfg['chSBR'], snr=tpeCfg['snrTest'])
    chOutTest = chTest.run(chIn=chInTest, flagN=tpeCfg['noiseFlag'])

    if tpeCfg['useFwdBwdNeuralEq']:
        N_size = params['N']
        print("------Architecture------", flush=True)
        print(f"Chosen N Size: {N_size}")
        print("------------------------", flush=True)
    else:
        Hidden_Size_1 = params['Hidden_Size_1']
        Hidden_Size_2 = params['Hidden_Size_2']
        print("--------------------------Architecture--------------------------", flush=True)
        print(f"Chosen Layer Structure: [{Hidden_Size_1}, {Hidden_Size_2}]")
        print(f"Initiallizing Model: [{tpeCfg['inSize']} - {Hidden_Size_1} - {Hidden_Size_2} - {tpeCfg['outSize'] * modNum}]")
        print("----------------------------------------------------------------", flush=True)

    berTest = None
    for selTrainData in tpeCfg['selTrainDataList']:
        for lossFn in tpeCfg['lossFnList']:
            for snrTrain in tpeCfg['snrTrainList']:
                #@@ Neural network definition.
                if tpeCfg['useFwdBwdNeuralEq']:
                    nEQ = FwdBwdNeuralEq(
                        tpeCfg['hiddenStage'],
                        tpeCfg['inSize'],
                        delay,
                        N_size,
                        tpeCfg['batchSize'],
                        tpeCfg['mod'],
                    )
                else:
                    nEQ = neuralEQ(
                        params,
                        inSize=tpeCfg['inSize'],
                        outSize=tpeCfg['outSize'] * modNum,
                        mod=tpeCfg['mod'],
                        nnSel=0,
                    )
                nEQ = nEQ.to(device.device)

                #@@ Adam optimizer; no learning-rate scheduler is used.
                opt = torch.optim.Adam(
                    nEQ.parameters(),
                    lr=tpeCfg['lr'],
                    weight_decay=tpeCfg['weightDecay'])

                summary(
                    nEQ,
                    (tpeCfg['batchSize'], tpeCfg['inSize']),
                    batch_size=tpeCfg['batchSize'],
                    device=device.device,
                )

                #@@ Training data: load it (with fwdBwd probabilities) from the
                #@@ cache when present, otherwise generate it.
                if tpeCfg['mismatchSNR'] is not None:
                    snrCh = snrTrain + tpeCfg['mismatchSNR']
                else:
                    snrCh = snrTrain
                (fwdBwdProbFileName,
                 fwdBwdProbChOutFileName,
                 fwdBwdProbChInFileName) = _fwdBwdCacheFileNames(tpeCfg, snrCh)

                if os.path.exists(fwdBwdProbFileName):
                    print("")
                    print("File(%s)\texists,\tload from file" % fwdBwdProbFileName)
                    print("")
                    fwdBwdProbTrain = _loadList(fwdBwdProbFileName)
                    chOutTrain = _loadList(fwdBwdProbChOutFileName)
                    chInTrain = _loadList(fwdBwdProbChInFileName)
                else:
                    #@@ Train sequence generation.
                    chInTrain = tx.run(int(tpeCfg['dataSizeTrain']))
                    ch = Channel(sbr=tpeCfg['chSBR'], snr=snrCh)
                    chOutTrain = ch.run(chIn=chInTrain, flagN=tpeCfg['noiseFlag'])

                    # Compare the current sweep element, not the whole
                    # selTrainDataList (a list is never equal to 0).
                    if selTrainData == 0:
                        print("")
                        print("File(%s)\tno exists, excute fwdBwd" % fwdBwdProbFileName)
                        print("")
                        #@@ Run fwdBwd on the generated channel output.
                        sweepForTrain = simSweep(
                            chSbr=tpeCfg['chSBR'],
                            eqSbr=tpeCfg['eqSBR'],
                            snrList=[snrTrain],
                            originData=chInTrain,
                            chOutList=[chOutTrain],
                            mod=tpeCfg['mod'],
                            stateGen=True,
                        )
                        _, fwdBwdProbTrain = sweepForTrain.fwdBwd(fwdBwdLen=tpeCfg['inSize'])
                        _saveList(fwdBwdProbFileName, fwdBwdProbTrain)
                        _saveList(fwdBwdProbChOutFileName, chOutTrain)
                        _saveList(fwdBwdProbChInFileName, chInTrain)

                if selTrainData == 0:
                    # NOTE(review): fwdBwdProbTrain is not passed to trainEval below.
                    fwdBwdProbTrain = np.array(fwdBwdProbTrain)

                if tpeCfg['onTheFly']:
                    # Presumably trainEval then generates its own training data
                    # from tx/snrTrain -- confirm against trainEval.
                    chInTrain = None
                    chOutTrain = None

                trainLossHis, validLossHis, berValidHis = trainEval(
                    nEQ,
                    tx,
                    chInValid,
                    chOutValid,
                    tpeCfg['numEpoch'],
                    tpeCfg['evalFreq'],
                    tpeCfg['mod'],
                    tpeCfg['chSBR'],
                    tpeCfg['inSize'],
                    tpeCfg['outSize'],
                    tpeCfg['batchSize'],
                    delay + delayOffset,
                    lossFn,
                    opt,
                    int(tpeCfg['dataSizeTrain']),
                    snrTrain,
                    tpeCfg['noiseFlag'],
                    chInTrain,
                    chOutTrain,
                )

                #@@ Save the trained network for each (snrTrain, selTrainData, lossFn).
                torch.save(nEQ, '%s/nEQ_%s_%ddB_simp%d_%s.pt' % (
                    resultDir, tpeCfg['mod'], snrTrain, selTrainData, lossFn))

                #@@ Finally, run the nEQ on the test set.
                simNEQ = simNeuralEQ(
                    txDataTrain=None,
                    rxDataTrain=None,
                    txDataTest=chInTest,
                    rxDataTest=chOutTest,
                    neuralEQ=nEQ,
                    mod=tpeCfg['mod'],
                )
                testLoss, berTest = simNEQ.evalNeuralEQ(
                    lossFn,
                    batchSize=tpeCfg['batchSize'],
                    inSize=tpeCfg['inSize'],
                    outSize=tpeCfg['outSize'],
                    delay=delay + delayOffset,
                )

                print("-----------------------Intermediate Result----------------------", flush=True)
                print(f"selTrainData: {selTrainData}")
                print(f"lossFn: {lossFn}")
                print(f"Testber: {berTest:e}", flush=True)
                timeSim = round((time.time() - startTime) / 60., 1)  # Unit: minutes
                print(f"Total simulation time: {timeSim} mins")
                print("----------------------------------------------------------------\n\n", flush=True)

                if tpeCfg['plotLoss']:
                    fig = plt.figure()
                    plt.plot(trainLossHis, '-', label='trainloss')
                    plt.grid(True)
                    plt.xlabel('epoch')
                    plt.ylabel('loss')
                    plt.legend(loc='best')
                    plt.savefig('%s/loss_%s_%ddB.png' % (resultDir, tpeCfg['mod'], snrTrain))
                    # Close it: otherwise every sweep point of every TPE trial
                    # leaves another open figure behind.
                    plt.close(fig)

                    plotValidBer = False  # set True to also save the validation-BER curve
                    if plotValidBer:
                        fig = plt.figure()
                        plt.plot(berValidHis, '-', label='validber')
                        plt.plot([berTest] * len(berValidHis), '--', label='nnFinalBer')
                        plt.grid(True)
                        plt.yscale('log')
                        plt.ylim([1e-4, 1])
                        plt.xlabel('epoch')
                        plt.ylabel('ber(accuracy)')
                        plt.legend(loc='best')
                        plt.savefig('%s/ber_%s_%ddB.png' % (resultDir, tpeCfg['mod'], snrTrain))
                        plt.close(fig)

    if berTest is None:
        raise ValueError(
            "selTrainDataList, lossFnList and snrTrainList must all be non-empty")
    return {'loss': berTest, 'params': params, 'status': STATUS_OK}

def search_space():
    """Build the hyperopt search space for the configured equalizer type.

    The config module is chosen by the command-line ``config`` argument.
    For the forward-backward neural EQ only 'N' is searched (even values
    8..128, 40 evaluations). Otherwise two hidden-layer widths are searched
    (multiples of 8 in 64..512, 50 evaluations).

    Returns:
        (space, evals): the hyperopt space dict and the max_evals budget.
    """
    np.random.seed(1)
    cliArgs = parsing_def()
    sys.path.insert(0, './config')
    cfgModule = __import__('config_{}'.format(cliArgs.config))
    tpeCfg = cfgModule.config['tpe']

    if not tpeCfg['useFwdBwdNeuralEq']:
        fcSpace = {
            'Hidden_Size_1': hp.choice('Hidden_Size_1', np.arange(64, 513, 8, dtype=int)),
            'Hidden_Size_2': hp.choice('Hidden_Size_2', np.arange(64, 513, 8, dtype=int)),
            #'Hidden_Size_3': hp.choice('Hidden_Size_3', [0]),
        }
        return fcSpace, 50

    fwdBwdSpace = {'N': hp.choice('N', np.arange(8, 129, 2, dtype=int))}
    return fwdBwdSpace, 40

## Config Setting
# Module-level config: the TPE driver below reads cfg and delayOffset from here.
args = parsing_def()
sys.path.insert(0, './config')
config_module = __import__('config_{}'.format(args.config))
cfg = config_module.config
delay = int((cfg['tpe']['inSize']) / 4)
if cfg['tpe']['delayOffset'] is None:
    # Default: minus the index of the largest channel SBR tap.
    delayOffset = -list(cfg['tpe']['chSBR']).index(max(cfg['tpe']['chSBR']))
else:
    delayOffset = cfg['tpe']['delayOffset']

## Checking for individual candidate
# Debug switch: change to `if 1:` to train one hand-picked configuration.
if 0:
    eqKind = ("Initializing FwdBwd NeuralEQ" if cfg['tpe']['useFwdBwdNeuralEq']
              else "Initialzing FC NeuralEQ")
    rule = "----------------------------------------------------------------"
    for line in (rule,
                 "Beginning Training Process...",
                 eqKind,
                 "Network Parameters:",
                 f"Epoch : {cfg['tpe']['numEpoch']}",
                 f"Learning Rate : {cfg['tpe']['lr']}",
                 f"SBR : {cfg['tpe']['chSBR']}",
                 f"Delay Offset : {delayOffset}",
                 rule,
                 "\n"):
        print(line)

    params = {
        #"Hidden_Size_1": 272,
        #"Hidden_Size_2": 504,
        # "Hidden_Size_3": 0,
        "N": 51,
    }
    Objective(params)

## Running TPE Strategy
# Only launch the (long) search when run as a script, not on import.
if __name__ == "__main__":
    space, evals = search_space()

    # To change the number of random start-up trials:
    # tpe_algorithm = functools.partial(tpe.suggest, n_startup_jobs=10)
    tpe_algorithm = tpe.suggest

    trials = Trials()

    print("----------------------------------------------------------------")
    print("Beginning TPE Process...")
    if cfg['tpe']['useFwdBwdNeuralEq']:
        print("Initializing FwdBwd NeuralEQ")
    else:
        print("Initialzing FC NeuralEQ")
    print("Network Parameters:")
    print(f"Epoch : {cfg['tpe']['numEpoch']}")
    print(f"Learning Rate : {cfg['tpe']['lr']}")
    print(f"SBR : {cfg['tpe']['chSBR']}")
    print(f"Delay Offset : {delayOffset}")
    print("----------------------------------------------------------------")
    print("\n")

    best = fmin(fn=Objective, space=space, algo=tpe_algorithm,
                max_evals=evals, trials=trials)

    best_trial = trials.best_trial
    # fmin (and trial['misc']['vals']) report hp.choice parameters as INDICES
    # into the option lists; space_eval maps them back to the chosen values.
    # numpy scalars are converted to plain Python numbers for readable output.
    best_params = {
        key: (val.item() if isinstance(val, np.generic) else val)
        for key, val in hyperopt.space_eval(space, best).items()
    }

    print("--------------------------Final Result--------------------------")
    print('Best Hyperparameter: {}'.format(best_params))
    print('Best Loss: {}'.format(best_trial['result']['loss']))
    print("----------------------------------------------------------------")

